{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "# Load the MNIST train/test splits through the Keras datasets helper.\n",
    "mnist = tf.keras.datasets.mnist\n",
    "train_split, test_split = mnist.load_data()\n",
    "train_images, train_labels = train_split\n",
    "test_images, test_labels = test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOR0lEQVR4nO3df6jUdb7H8df7titBrmF5klNK7l3OP7GQ2iC3jPXc9C4mkS1BKricS4XST5eMbrh/rJSBSNsSFEvuTdYTm9vSWorF7nbFiIVaG+WUVlzrhqHmjxFBkyLX9n3/ON+Wk53vZ8aZ78x39P18wDAz3/d8z/fdt159Z76f+c7H3F0Azn//UnYDADqDsANBEHYgCMIOBEHYgSC+08mNTZgwwadMmdLJTQKh7N27V0ePHrXRai2F3czmSnpS0gWS/tvdV6deP2XKFFWr1VY2CSChUqnk1pp+G29mF0h6WtKNkq6StMjMrmr27wFor1Y+s8+Q9JG7f+zupyT9XtL8YtoCULRWwn6FpH0jnu/Pln2DmS0xs6qZVWu1WgubA9CKtp+Nd/e17l5x90pPT0+7NwcgRythPyBp8ojnk7JlALpQK2F/W1KfmX3fzMZIWihpczFtASha00Nv7n7azO6V9GcND72tc/f3CusMQKFaGmd391clvVpQLwDaiK/LAkEQdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxBER6dsxvlnx44dyfpTTz2VW1u/fn1y3YGBgWT9vvvuS9anT5+erEfDkR0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgmCcHUlDQ0PJ+pw5c5L1EydO5NbMLLnu4OBgsr5p06Zk/dixY8l6NC2F3cz2SvpM0leSTrt7pYimABSviCP7v7v70QL+DoA24jM7EESrYXdJfzGzHWa2ZLQXmNkSM6uaWbVWq7W4OQDNajXs17v7dEk3SrrHzH505gvcfa27V9y90tPT0+LmADSrpbC7+4Hs/oiklyTNKKIpAMVrOuxmdpGZfe/rx5J+LGl3UY0BKFYrZ+MnSnopGyv9jqTn3f1PhXSFjtm+fXuyfuuttybrx48fT9ZTY+njxo1LrjtmzJhk/ejR9CDQm2++mVu75pprWtr2uajpsLv7x5KuLrAXAG3E0BsQBGEHgiDsQBCEHQiCsANBcInreeDzzz/Pre3cuTO57uLFi5P1Tz/9tKmeGtHX15esP/TQQ8n6ggULkvWZM2fm1latWpVcd8WKFcn6uYgjOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EwTj7eWDp0qW5teeff76DnZydetM9nzx5MlmfNWtWsv7666/n1nbt2pVc93zEkR0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgmCc/RxQbzx6y5YtuTV3b2nb/f39yfpNN92UrD/44IO5tcsvvzy57rRp05L18ePHJ+vbtm3LrbW6X85FHNmBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjG2bvA0NBQsj5nzpxk/cSJE7m11JTJkjRv3rxkfcOGDcl66ppxSXrsscdya3feeWdy3Z6enmT96qvTkwin/tlfeeWV5Lr1fm9/+vTpyXo3qntkN7N1ZnbEzHaPWHaJmb1mZh9m9+lvNwAoXSNv438rae4Zyx6WtNXd+yRtzZ4D6GJ1w+7ub0g6dsbi+ZLWZ4/XS7ql2LYAFK3ZE3QT3f1g9viQpIl5LzSzJWZWNbNqrVZrcnMAWtXy2XgfvqIg96oCd1/r7hV3r9Q74QKgfZoN+2Ez65Wk7P5IcS0BaIdmw75Z0kD2eEDSpmLaAdAudcfZzWyDpH5JE8xsv6RfSFot6Q9mdoekTyTd1s4mz3V79uxJ1tesWZOsHz9+PFlPfTzq7e1NrjswMJCsjx07Nlmvdz17vXpZUnPaS9Ljjz+e
rHfz7/HnqRt2d1+UU5pdcC8A2oivywJBEHYgCMIOBEHYgSAIOxAEl7gW4Msvv0zWUz+nLNW/3HLcuHHJ+uDgYG6tUqkk1/3iiy+S9aj27dtXdguF48gOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0Ewzl6Aej87XG8cvZ5Nm9I/FzBr1qyW/j5i4MgOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0Ewzl6ABx54IFkfnjQnX39/f7LOOHpz6u33dq3brTiyA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQjLM3aMuWLbm1oaGh5LpmlqzffPPNzbSEOlL7vd6/k6lTpxbcTfnqHtnNbJ2ZHTGz3SOWrTSzA2Y2lN3mtbdNAK1q5G38byXNHWX5r9x9anZ7tdi2ABStbtjd/Q1JxzrQC4A2auUE3b1m9m72Nn983ovMbImZVc2sWqvVWtgcgFY0G/ZfS/qBpKmSDkr6Zd4L3X2tu1fcvdLT09Pk5gC0qqmwu/thd//K3f8h6TeSZhTbFoCiNRV2M+sd8fQnknbnvRZAd6g7zm5mGyT1S5pgZvsl/UJSv5lNleSS9kpa2r4Wu0NqHvNTp04l173sssuS9QULFjTV0/mu3rz3K1eubPpvz549O1lfvXp103+7W9UNu7svGmXxs23oBUAb8XVZIAjCDgRB2IEgCDsQBGEHguAS1w648MILk/Xe3t5k/XxVb2ht1apVyfqaNWuS9cmTJ+fWli9fnlx37Nixyfq5iCM7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgTBOHsHRP6p6NTPbNcbJ3/hhReS9fnz5yfrGzduTNaj4cgOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0Ewzt4gd2+qJkkvv/xysv7kk08201JXeOKJJ5L1Rx99NLd2/Pjx5LqLFy9O1gcHB5N1fBNHdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgnH2BplZUzVJOnToULJ+//33J+u33357sn7ppZfm1t56663kus8991yy/s477yTr+/btS9avvPLK3NrcuXOT6959993JOs5O3SO7mU02s21m9r6ZvWdmy7Lll5jZa2b2YXY/vv3tAmhWI2/jT0ta7u5XSfo3SfeY2VWSHpa01d37JG3NngPoUnXD7u4H3X1n9vgzSR9IukLSfEnrs5etl3RLm3oEUICzOkFnZlMkTZP0N0kT3f1gVjokaWLOOkvMrGpm1Vqt1kqvAFrQcNjNbKykP0r6mbufGFnz4StBRr0axN3XunvF3Ss9PT0tNQugeQ2F3cy+q+Gg/87dv/7JzsNm1pvVeyUdaU+LAIpQd+jNhseVnpX0gbuPvJ5xs6QBSauz+01t6fA8cPr06WT96aefTtZffPHFZP3iiy/Ore3Zsye5bquuu+66ZP2GG27IrT3yyCNFt4OERsbZZ0r6qaRdZjaULVuh4ZD/wczukPSJpNva0iGAQtQNu7v/VVLet0ZmF9sOgHbh67JAEIQdCIKwA0EQdiAIwg4EwSWuDbr22mtzazNmzEiuu3379pa2Xe8S2cOHDzf9tydMmJCsL1y4MFk/l38GOxqO7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBOPsDZo0aVJubePGjbk1SXrmmWeS9dS0xq1atmxZsn7XXXcl6319fUW2gxJxZAeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIGx4MpfOqFQqXq1WO7Y9IJpKpaJqtTrqr0FzZAeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIOqG3cwmm9k2M3vfzN4zs2XZ8pVmdsDMhrLbvPa3C6BZjfx4xWlJy919p5l9T9IOM3stq/3K3R9vX3sAitLI/OwHJR3MHn9mZh9IuqLdjQEo1ll9ZjezKZKmSfpbtuheM3vXzNaZ2ficdZaYWdXMqrVarbVuATSt4bCb2VhJf5T0M3c/IenXkn4gaaqGj/y/HG09d1/r7hV3r/T09LTeMYCmNBR2M/uuhoP+
O3ffKEnuftjdv3L3f0j6jaT07IYAStXI2XiT9KykD9z9iRHLe0e87CeSdhffHoCiNHI2fqakn0raZWZD2bIVkhaZ2VRJLmmvpKVt6A9AQRo5G/9XSaNdH/tq8e0AaBe+QQcEQdiBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiio1M2m1lN0icjFk2QdLRjDZydbu2tW/uS6K1ZRfZ2pbuP+vtvHQ37tzZuVnX3SmkNJHRrb93al0RvzepUb7yNB4Ig7EAQZYd9bcnbT+nW3rq1L4nemtWR3kr9zA6gc8o+sgPoEMIOBFFK2M1srpn9r5l9ZGYPl9FDHjPba2a7smmoqyX3ss7MjpjZ7hHLLjGz18zsw+x+1Dn2SuqtK6bxTkwzXuq+K3v6845/ZjezCyTtkfQfkvZLelvSInd/v6ON5DCzvZIq7l76FzDM7EeSTkoadPcfZsvWSDrm7quz/1GOd/f/6pLeVko6WfY03tlsRb0jpxmXdIuk/1SJ+y7R123qwH4r48g+Q9JH7v6xu5+S9HtJ80voo+u5+xuSjp2xeL6k9dnj9Rr+j6XjcnrrCu5+0N13Zo8/k/T1NOOl7rtEXx1RRtivkLRvxPP96q753l3SX8xsh5ktKbuZUUx094PZ40OSJpbZzCjqTuPdSWdMM941+66Z6c9bxQm6b7ve3adLulHSPdnb1a7kw5/BumnstKFpvDtllGnG/6nMfdfs9OetKiPsByRNHvF8UrasK7j7gez+iKSX1H1TUR/+egbd7P5Iyf38UzdN4z3aNOPqgn1X5vTnZYT9bUl9ZvZ9MxsjaaGkzSX08S1mdlF24kRmdpGkH6v7pqLeLGkgezwgaVOJvXxDt0zjnTfNuEred6VPf+7uHb9JmqfhM/L/J+nnZfSQ09e/Snonu71Xdm+SNmj4bd3fNXxu4w5Jl0raKulDSf8j6ZIu6u05SbskvavhYPWW1Nv1Gn6L/q6koew2r+x9l+irI/uNr8sCQXCCDgiCsANBEHYgCMIOBEHYgSAIOxAEYQeC+H8dj1XrNRdSpAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def plot_image(image):\n",
    "    \"\"\"Show one sample as a 28x28 image using matplotlib's `binary` colormap.\"\"\"\n",
    "    pixels = image.reshape(28, 28)\n",
    "    plt.imshow(pixels, cmap=\"binary\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Preview the second training sample.\n",
    "plot_image(train_images[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(48000, 784), dtype=float32, numpy=\n",
       "array([[0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       ...,\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hold out the trailing `valid_rate` fraction of the training data for validation.\n",
    "valid_rate = 0.2\n",
    "len_train = int(len(train_images)*(1-valid_rate))\n",
    "\n",
    "# Labels are cut at the same index as the images so the pairs stay aligned.\n",
    "y_train, y_valid = train_labels[:len_train], train_labels[len_train:]\n",
    "y_test = test_labels\n",
    "\n",
    "# Images: split, flatten every sample to a 784-vector, divide by 255.0,\n",
    "# then cast to float32 tensors.\n",
    "x_train = tf.cast(train_images[:len_train].reshape(-1,784)/255.0, dtype=tf.float32)\n",
    "x_valid = tf.cast(train_images[len_train:].reshape(-1,784)/255.0, dtype=tf.float32)\n",
    "x_test = tf.cast(test_images.reshape(-1,784)/255.0, dtype=tf.float32)\n",
    "\n",
    "# Display the training tensor as a quick shape/dtype check.\n",
    "x_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot encode the integer class labels (10 classes).\n",
    "# Encode straight from the raw label arrays instead of rebinding y_* to\n",
    "# one_hot(y_*): this keeps the cell idempotent, so re-running it never feeds\n",
    "# already-encoded float tensors back into tf.one_hot.\n",
    "y_train = tf.one_hot(train_labels[:len_train],depth=10)\n",
    "y_valid = tf.one_hot(train_labels[len_train:],depth=10)\n",
    "y_test = tf.one_hot(test_labels,depth=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "Input_Dim = 784   # length of one flattened input sample (x_* are reshaped to (-1, 784))\n",
    "H1_NN = 64        # width of the hidden layer\n",
    "Output_Dim = 10   # number of output classes (same as the one-hot depth)\n",
    "\n",
    "# Fix the global TensorFlow seed so the random initial weights are the same\n",
    "# on every fresh run of the notebook (and on every re-run of this cell).\n",
    "RANDOM_SEED = 42\n",
    "tf.random.set_seed(RANDOM_SEED)\n",
    "\n",
    "# Layer 1: weights drawn from a normal distribution (mean 0, stddev 1), biases zero.\n",
    "W1 = tf.Variable(tf.random.normal([Input_Dim,H1_NN],mean=0.0,stddev=1.0,dtype=tf.float32))\n",
    "B1 = tf.Variable(tf.zeros([H1_NN]),dtype=tf.float32)\n",
    "\n",
    "# Layer 2: same initialisation scheme, hidden -> output.\n",
    "W2 = tf.Variable(tf.random.normal([H1_NN,Output_Dim],mean=0.0,stddev=1.0,dtype=tf.float32))\n",
    "B2 = tf.Variable(tf.zeros([Output_Dim]),dtype=tf.float32)\n",
    "\n",
    "# Parameter lists in layer order; the training code indexes them as w[0], w[1], b[0], b[1].\n",
    "W = [W1,W2]\n",
    "B = [B1,B2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(x,w,b):\n",
    "    \"\"\"Forward pass of a two-layer network: ReLU hidden layer, softmax output.\n",
    "\n",
    "    x is multiplied by w[0] and shifted by b[0], passed through ReLU, then\n",
    "    multiplied by w[1] and shifted by b[1]; the softmax of that result is returned.\n",
    "    \"\"\"\n",
    "    hidden = tf.nn.relu(tf.matmul(x, w[0]) + b[0])\n",
    "    logits = tf.matmul(hidden, w[1]) + b[1]\n",
    "    return tf.nn.softmax(logits)\n",
    "\n",
    "def loss(x,w,b,y):\n",
    "    \"\"\"Mean categorical cross-entropy between targets y and model(x, w, b).\"\"\"\n",
    "    per_sample = tf.keras.losses.categorical_crossentropy(y_true=y, y_pred=model(x,w,b))\n",
    "    return tf.reduce_mean(per_sample)\n",
    "\n",
    "# Optimisation settings read by the training loop.\n",
    "training_epochs = 20    # number of passes over the training split\n",
    "batch_size = 50         # samples per gradient step\n",
    "learning_rate = 0.01    # learning rate handed to the optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "def grad(x,w,b,y):\n",
    "    \"\"\"Gradients of loss(x, w, b, y) with respect to [w[0], w[1], b[0], b[1]], in that order.\"\"\"\n",
    "    trainable = [w[0], w[1], b[0], b[1]]\n",
    "    with tf.GradientTape() as tape:\n",
    "        batch_loss = loss(x,w,b,y)\n",
    "    return tape.gradient(batch_loss, trainable)\n",
    "\n",
    "def accuracy(w,x,b,y):\n",
    "    \"\"\"Share of samples whose predicted class matches the target class.\n",
    "\n",
    "    Both classes are taken as the argmax along axis 1 of model(x, w, b) and of y.\n",
    "    Note the parameter order (w, x, b, y) differs from model/loss/grad.\n",
    "    \"\"\"\n",
    "    predicted = tf.argmax(model(x,w,b), 1)\n",
    "    expected = tf.argmax(y, 1)\n",
    "    hits = tf.cast(tf.equal(predicted, expected), tf.float32)\n",
    "    return tf.reduce_mean(hits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 loss: 4.3852277 accuracy 0.71316665\n",
      "epoch: 2 loss: 3.9842672 accuracy 0.7406458\n",
      "epoch: 3 loss: 3.7971215 accuracy 0.752625\n",
      "epoch: 4 loss: 2.1256797 accuracy 0.85772914\n",
      "epoch: 5 loss: 2.1432323 accuracy 0.8555833\n",
      "epoch: 6 loss: 2.0809486 accuracy 0.85864586\n",
      "epoch: 7 loss: 1.9530882 accuracy 0.86997914\n",
      "epoch: 8 loss: 2.202441 accuracy 0.8470625\n",
      "epoch: 9 loss: 1.9232783 accuracy 0.8710833\n",
      "epoch: 10 loss: 1.933887 accuracy 0.8711875\n",
      "epoch: 11 loss: 1.875243 accuracy 0.87677085\n",
      "epoch: 12 loss: 0.5360361 accuracy 0.9576667\n",
      "epoch: 13 loss: 0.31598893 accuracy 0.9742917\n",
      "epoch: 14 loss: 0.35601285 accuracy 0.9695\n",
      "epoch: 15 loss: 0.38904598 accuracy 0.9676667\n",
      "epoch: 16 loss: 0.2568358 accuracy 0.9787083\n",
      "epoch: 17 loss: 0.3000349 accuracy 0.974375\n",
      "epoch: 18 loss: 0.31668785 accuracy 0.97425\n",
      "epoch: 19 loss: 0.3129103 accuracy 0.97495836\n",
      "epoch: 20 loss: 0.27527666 accuracy 0.97727084\n"
     ]
    }
   ],
   "source": [
    "optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate)\n",
    "\n",
    "# Per-epoch history of loss/accuracy on the training and validation splits.\n",
    "lost_list_of_train = []\n",
    "lost_list_of_valid = []\n",
    "accuracy_list_of_train = []\n",
    "accuracy_list_of_valid = []\n",
    "\n",
    "for epoch in range(training_epochs):\n",
    "    # Walk the training split in order, one mini-batch at a time. Iterating by\n",
    "    # start offset (rather than int(len_train/batch_size) full steps) also\n",
    "    # trains on a trailing partial batch when len_train is not a multiple of\n",
    "    # batch_size.\n",
    "    for start in range(0, len_train, batch_size):\n",
    "        stop = min(start + batch_size, len_train)\n",
    "        rx = x_train[start:stop]\n",
    "        ry = y_train[start:stop]\n",
    "        grad_ = grad(rx,W,B,ry)\n",
    "        # grad() returns gradients ordered [w[0], w[1], b[0], b[1]], which lines up with W+B.\n",
    "        optimizer.apply_gradients(zip(grad_,W+B))\n",
    "\n",
    "    # End-of-epoch evaluation on the full training and validation splits.\n",
    "    loss_of_train = loss(x_train,W,B,y_train).numpy()\n",
    "    loss_of_valid = loss(x_valid,W,B,y_valid).numpy()\n",
    "    accuracy_of_train = accuracy(W,x_train,B,y_train).numpy()\n",
    "    accuracy_of_valid = accuracy(W,x_valid,B,y_valid).numpy()\n",
    "    lost_list_of_train.append(loss_of_train)\n",
    "    lost_list_of_valid.append(loss_of_valid)\n",
    "    accuracy_list_of_train.append(accuracy_of_train)\n",
    "    accuracy_list_of_valid.append(accuracy_of_valid)\n",
    "    print(\"epoch:\",(epoch+1),\"loss:\",loss_of_train,\"accuracy\",accuracy_of_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
